# -*- coding: utf-8 -*-
# @Time:  21:55
# @Author: tk
# @File：colossalai_config


# Strategy presets, looked up by strategy name. Every preset repeats its own
# lookup key under "name"; all remaining entries are plain option values.
# NOTE(review): the option names look like keyword arguments of the ColossalAI
# booster plugins (and torch DDP for "ddp") -- confirm against the code that
# consumes this table before renaming or removing any key.
colossalai_strategy = {
    "ddp": {
        "name": "ddp",
        "broadcast_buffers": True,
        "bucket_cap_mb": 25,
        "find_unused_parameters": False,
        "check_reduction": False,
        "gradient_as_bucket_view": False,
        "static_graph": False,
    },
    "gemini": {
        "name": "gemini",
        "chunk_config_dict": None,
        "chunk_init_device": None,
        "placement_policy": "static",
        # only for static placement
        "shard_param_frac": 1.0,
        "offload_optim_frac": 0.0,
        "offload_param_frac": 0.0,
        # only for auto placement
        "warmup_non_model_data_ratio": 0.8,
        "steady_cuda_cap_ratio": 0.9,
        "precision": "fp16",
        "pin_memory": False,
        "force_outputs_fp32": False,
        "strict_ddp_mode": False,
        "search_range_m": 32,
        "hidden_dim": None,
        "min_chunk_size_m": 32,
        "memstats": None,
        "gpu_margin_mem_ratio": 0.0,
        "initial_scale": 2 ** 16,
        "min_scale": 1,
        "growth_factor": 2,
        "backoff_factor": 0.5,
        "growth_interval": 1000,
        "hysteresis": 2,
        "max_scale": 2 ** 32,
        "max_norm": 1.0,
        "norm_type": 2.0,
        "verbose": False,
    },
    # The two zero2 presets start at initial_scale 2 ** 32, whereas "gemini"
    # and "3d" start at 2 ** 16.
    "zero2": {
        "name": "zero2",
        "stage": 2,
        "precision": "fp16",
        "initial_scale": 2 ** 32,
        "min_scale": 1,
        "growth_factor": 2,
        "backoff_factor": 0.5,
        "growth_interval": 1000,
        "hysteresis": 2,
        "max_scale": 2 ** 32,
        "max_norm": 1.0,
        "norm_type": 2.0,
        "reduce_bucket_size_in_m": 12,
        "communication_dtype": None,
        "overlap_communication": True,
        "cpu_offload": False,
        "verbose": False,
    },
    # Same values as "zero2" except for "name" and cpu_offload=True. Kept as a
    # separate, fully spelled-out dict so the two presets never share state.
    "zero2_cpu": {
        "name": "zero2_cpu",
        "stage": 2,
        "precision": "fp16",
        "initial_scale": 2 ** 32,
        "min_scale": 1,
        "growth_factor": 2,
        "backoff_factor": 0.5,
        "growth_interval": 1000,
        "hysteresis": 2,
        "max_scale": 2 ** 32,
        "max_norm": 1.0,
        "norm_type": 2.0,
        "reduce_bucket_size_in_m": 12,
        "communication_dtype": None,
        "overlap_communication": True,
        "cpu_offload": True,
        "verbose": False,
    },
    # Unlike the presets above, "3d" has max_norm set to integer 0 and carries
    # no norm_type / verbose entries.
    "3d": {
        "name": "3d",
        "tp_size": 1,
        "pp_size": 1,
        "precision": "fp16",
        "zero_stage": 0,
        "enable_all_optimization": False,
        "enable_fused_normalization": False,
        "enable_flash_attention": False,
        "enable_jit_fused": False,
        "enable_sequence_parallelism": False,
        "enable_sequence_overlap": False,
        "num_microbatches": None,
        "microbatch_size": None,
        "initial_scale": 2 ** 16,
        "min_scale": 1,
        "growth_factor": 2,
        "backoff_factor": 0.5,
        "growth_interval": 1000,
        "hysteresis": 2,
        "max_scale": 2 ** 32,
        "max_norm": 0,
        "broadcast_buffers": True,
        "ddp_bucket_cap_mb": 25,
        "find_unused_parameters": False,
        "check_reduction": False,
        "gradient_as_bucket_view": False,
        "static_graph": False,
        "zero_bucket_size_in_m": 12,
        "cpu_offload": False,
        "communication_dtype": None,
        "overlap_communication": True,
        "custom_policy": None,
    },
}